
import math
import logging
import torch as th
import torch.nn as nn
from torch.nn import functional as F
from utils.embed import polynomial_embed, binary_embed
import numpy as np

# Module-level logger; GPT.__init__ uses it to report the model's parameter count.
logger = logging.getLogger(__name__)
   

class GELU(nn.Module):
    """Parameter-free module that applies the functional GELU activation element-wise."""

    def forward(self, input):
        activated = F.gelu(input)
        return activated


class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention with separate key/query/value projections.

    Position ``t`` attends only to positions ``<= t`` of the same sequence. The
    causal mask is rebuilt on every call on the input's device (it is not stored
    as a module buffer), so the module runs on any device and its ``state_dict``
    does not depend on the sequence lengths seen during forward passes.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, C) with C == config.n_embd.
            layer_past: unused; accepted for interface compatibility.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        def _split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        k = _split_heads(self.key(x))
        q = _split_heads(self.query(x))
        v = _split_heads(self.value(x))

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # Lower-triangular (T, T) mask built on x's device; broadcasts over (B, n_head).
        causal = th.ones(T, T, dtype=th.bool, device=x.device).tril()
        att = att.masked_fill(~causal, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        # Re-assemble all head outputs side by side: (B, n_head, T, hd) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.resid_drop(self.proj(y))
        return y


class Block(nn.Module):
    """Pre-LayerNorm transformer layer.

    Applies causal self-attention and then a two-layer GELU MLP (hidden width
    2 * n_embd). Each sub-layer is wrapped in a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden_width = 2 * width
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden_width),
            GELU(),
            nn.Linear(hidden_width, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))


class ObsEncoder4DT(nn.Module):
    """Turn flat per-agent observations into entity tokens for the decision transformer.

    For each input row, the module produces one "own" token, one token per enemy
    entity and one token per ally entity, each of width ``args.entity_embed_dim``.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, decomposer, args, device='cuda'):
        """
        Args:
            task2input_shape_info: mapping task -> dict containing "last_action_shape".
            task2decomposer: mapping task -> observation decomposer used in forward.
            task2n_agents: mapping task -> number of agents.
            decomposer: decomposer whose feature sizes define the projection layers.
            args: config providing skill_dim, entity_embed_dim, attn_embed_dim,
                id_length (and max_agent, read in forward).
            device: device the projection layers are moved to (default 'cuda').
        """
        super(ObsEncoder4DT, self).__init__()
        self.task2last_action_shape = {task: task2input_shape_info[task]["last_action_shape"] for task in
                                       task2input_shape_info}
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents
        self.args = args

        self.skill_dim = args.skill_dim

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        obs_own_dim = decomposer.own_obs_dim
        obs_en_dim, obs_al_dim = decomposer.obs_nf_en, decomposer.obs_nf_al
        # The own-observation token is concatenated with a binary agent-id code of length id_length.
        wrapped_obs_own_dim = obs_own_dim + args.id_length

        self.ally_value = nn.Linear(obs_al_dim, self.entity_embed_dim).to(device)
        self.enemy_value = nn.Linear(obs_en_dim, self.entity_embed_dim).to(device)
        self.own_value = nn.Linear(wrapped_obs_own_dim, self.entity_embed_dim).to(device)

    def init_hidden(self):
        """Return a zero tensor of shape (1, entity_embed_dim) on the encoder's device/dtype."""
        return self.own_value.weight.new_zeros(1, self.entity_embed_dim)

    def forward(self, inputs, task):
        """Encode a batch of per-agent inputs for ``task``.

        Args:
            inputs: 2-D tensor whose first ``obs_dim`` columns are the observation.
                Any remaining columns are ignored, because agent ids are regenerated
                here. Rows are assumed to be ordered batch-major with the agent
                index varying fastest, since the generated ids cycle 1..n_agents.
            task: key into the task2* mappings.

        Returns:
            Tensor of shape (rows, 1 + n_enemies + n_allies, entity_embed_dim),
            with token order [own, enemies..., allies...]. This assumes
            ``decompose_obs`` returns lists of (rows, feat) tensors — TODO confirm.
        """
        task_decomposer = self.task2decomposer[task]
        task_n_agents = self.task2n_agents[task]

        obs_dim = task_decomposer.obs_dim
        obs_inputs = inputs[:, :obs_dim]

        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs_inputs)
        bs = own_obs.shape[0] // task_n_agents

        # One binary id code per agent, tiled over the batch to line up with the rows.
        agent_id_inputs = [
            th.as_tensor(binary_embed(i + 1, self.args.id_length, self.args.max_agent), dtype=own_obs.dtype) for i in
            range(task_n_agents)]
        agent_id_inputs = th.stack(agent_id_inputs, dim=0).repeat(bs, 1).to(own_obs.device)

        own_obs = th.cat([own_obs, agent_id_inputs], dim=-1)

        enemy_feats = th.stack(enemy_feats, dim=0)
        ally_feats = th.stack(ally_feats, dim=0)

        own_hidden = self.own_value(own_obs).unsqueeze(1)
        # (n_entities, rows, E) -> (rows, n_entities, E)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)

        return th.cat([own_hidden, enemy_hidden, ally_hidden], dim=1)

class GPT(nn.Module):
    """GPT-style decision transformer that predicts a global skill and a per-agent skill.

    Every (timestep, agent) pair contributes the token group
    ``[rtg, obs tokens..., global skill, skill]``. Two heads are applied:
    ``skill_head`` reads the global-skill token, and ``gl_skill_head`` reads the
    last observation token.
    """

    def __init__(self, task2input_shape_info, config, task2decomposer, task2n_agents, decomposer):
        super().__init__()

        self.config = config
        self.n_layer = config.n_layer
        self.n_head = config.n_head
        # Overwritten below by a per-task dict of block sizes.
        self.block_size = config.block_size
        self.n_embd = config.n_embd
        self.model_type = config.model_type

        self.obs_encoder4dt = ObsEncoder4DT(task2input_shape_info,
                                            task2decomposer,
                                            task2n_agents,
                                            decomposer,
                                            config)

        self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh()).to('cuda')
        self.skill_emb = nn.Sequential(nn.Linear(config.skill_dim, config.n_embd), nn.Tanh()).to('cuda')
        # NOTE(review): not used in forward (global skills are embedded via the codebook).
        self.gl_skill_emb = nn.Sequential(nn.Linear(config.vqvae_D, config.n_embd), nn.Tanh()).to('cuda')

        self.block_size = {k: self._return_block_size(task2decomposer[k]) for k in task2input_shape_info.keys()}
        # Timestep position table: absolute step indices must be < 200.
        self.position_embed = nn.Embedding(200, config.n_embd).to('cuda')

        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]).to('cuda')

        self.ln_f = nn.LayerNorm(config.n_embd).to('cuda')
        self.skill_head = nn.Linear(config.n_embd, config.skill_dim, bias=False).to('cuda')
        self.gl_skill_head = nn.Linear(config.n_embd, config.vqvae_K, bias=False).to('cuda')

        self.attention4obs_tokens = nn.MultiheadAttention(embed_dim=config.n_embd, num_heads=2).to('cuda')

        # The VQ-VAE codebook is loaded from the first training task's offline data.
        dataset_folder = config.offline_data_folder4_gl_skill
        if config.debug:
            dataset_folder += "_debug"
        first_task = config.train_tasks[0]
        codebook_path = "/".join([
            dataset_folder,
            first_task,
            config.train_tasks_data_quality[first_task],
            "gl_skill_dim_" + str(config.vqvae_K),
            config.train_global_skill_time,
            str(config.offline_data_name),
            "codebook.npy",
        ])
        # Presumably (vqvae_K, n_embd): rows are indexed by global-skill id in forward and
        # concatenated with n_embd-wide tokens. This is a plain tensor, not a parameter/buffer.
        self.codebook = th.from_numpy(np.load(codebook_path)).to("cuda")

        logger.info("number of DT's parameters: %e", sum(p.numel() for p in self.parameters()))

    def _return_block_size(self, task2decomposer):
        """Number of agents + number of enemies + 2 for the given task decomposer."""
        return task2decomposer.n_agents + task2decomposer.n_enemies + 2

    def zero_grad(self, set_to_none=True):
        """Intentionally a no-op: calling ``model.zero_grad()`` does not clear gradients.

        Presumably gradients are cleared through the optimizer instead — confirm.
        ``set_to_none`` is accepted only for signature compatibility with
        ``nn.Module.zero_grad``.
        """
        pass

    def _attention_4_obs_embedding(self, obs_token_embeddings):
        """Pool observation tokens with self-attention and keep the first token's output.

        Args:
            obs_token_embeddings: (batch, timesteps, agents, num_tokens, n_embd).

        Returns:
            Tensor of shape (batch, timesteps, agents, 1, n_embd).
        """
        batch_size, timesteps, agent_num, num_tokens, n_embd = obs_token_embeddings.shape

        tokens = obs_token_embeddings.reshape(-1, num_tokens, n_embd)
        # nn.MultiheadAttention defaults to (seq, batch, embed) layout.
        tokens = tokens.transpose(0, 1)

        attn_output, _ = self.attention4obs_tokens(tokens, tokens, tokens)

        first_token_embedding = attn_output[0:1]
        return first_token_embedding.reshape(batch_size, timesteps, agent_num, 1, n_embd)

    def forward(self, input, task):
        """Run the transformer and predict global-skill logits and skill outputs.

        Args:
            input: tuple ``(obs_input, sk, gl_sk_id, rtg, step_start, step_end)``.
                ``sk`` is indexed as (bs, context_length, n_agent, ...). ``rtg`` is
                indexed as a 4-D tensor whose last dim feeds ``Linear(1, n_embd)``.
                ``step_end - step_start`` is presumably equal to ``context_length``
                — TODO confirm.
            task: task key forwarded to the observation encoder.

        Returns:
            ``(gl_skill_id, skill)``: outputs of ``gl_skill_head`` and ``skill_head``.
            Each has ``bs * context_length * n_agent`` rows, ordered
            (batch, timestep, agent).

        Raises:
            ValueError: if ``model_type`` or ``config.step_block_size`` is unsupported.
        """
        obs_input, sk, gl_sk_id, rtg, step_start, step_end = input
        if self.model_type != 'rtgs_state_action':
            raise ValueError(f"unsupported model_type: {self.model_type!r}")
        step_block_size = self.config.step_block_size
        if step_block_size < 1:
            raise ValueError(f"step_block_size must be >= 1, got {step_block_size!r}")

        bs, context_length, n_agent = sk.shape[0], sk.shape[1], sk.shape[2]
        n_embd = self.config.n_embd

        # Observation tokens: (bs, T, n_agent, n_obs_tokens, n_embd).
        obs_token_embeddings = self.obs_encoder4dt(obs_input, task)
        obs_token_embeddings = obs_token_embeddings.reshape(bs, step_end - step_start, n_agent, -1, n_embd)

        rtg_embeddings = self.ret_emb(rtg.type(th.float32)).unsqueeze(-2)
        sk_embeddings = self.skill_emb(sk.type(th.float32)).unsqueeze(-2)
        gl_sk_embeddings = self.codebook[gl_sk_id.type(th.int64)].type(th.float32)
        # Token order per (timestep, agent): [rtg, obs..., global skill, skill].
        token_embeddings = th.cat((rtg_embeddings, obs_token_embeddings, gl_sk_embeddings, sk_embeddings), dim=-2)
        n_tokens = token_embeddings.shape[-2]

        # All tokens of one timestep share that timestep's position embedding.
        t_tensor = th.zeros_like(rtg[:, :, 0, :])
        t_tensor[:, :, 0] = th.arange(step_start, step_end, device=t_tensor.device)
        position_embeddings = self.position_embed(t_tensor.long()) \
            .unsqueeze(2).repeat(1, 1, n_agent, n_tokens, 1)
        embeddings = token_embeddings + position_embeddings

        if step_block_size == 1:
            # Each (batch, timestep, agent) token group is its own sequence.
            x = self.blocks(embeddings.reshape(-1, n_tokens, n_embd))
            x = self.ln_f(x)
            skill_tokens = x[:, -2, :]
            gl_skill_tokens = x[:, -3, :]
        else:
            # One sequence per (batch, agent) spanning all timesteps.
            stacked = embeddings.permute(0, 2, 1, 3, 4) \
                .reshape(bs * n_agent, context_length * n_tokens, n_embd)
            x = self.ln_f(self.blocks(stacked))
            x = x.reshape(bs, n_agent, context_length, n_tokens, n_embd).permute(0, 2, 1, 3, 4)
            skill_tokens = x[:, :, :, -2, :].reshape(bs * context_length * n_agent, -1)
            gl_skill_tokens = x[:, :, :, -3, :].reshape(bs * context_length * n_agent, -1)

        skill = self.skill_head(skill_tokens)
        gl_skill_id = self.gl_skill_head(gl_skill_tokens)
        return gl_skill_id, skill

    def get_block_size(self):
        """Return the per-task dict of block sizes built in ``__init__``."""
        return self.block_size

    def _init_weights(self, module):
        """minGPT-style init: N(0, 0.02) weights, zero biases, unit LayerNorm scale.

        Not applied anywhere inside this class.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config, lr):
        """Build AdamW with weight decay applied only to projection weights.

        Decayed: weights of Linear/Conv2d modules and the packed or separate
        projection weights stored directly on MultiheadAttention. Not decayed:
        all biases, plus LayerNorm and Embedding weights.

        Args:
            train_config: provides ``weight_decay`` and ``betas``.
            lr: learning rate.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (th.nn.Linear, th.nn.Conv2d)
        blacklist_weight_modules = (th.nn.LayerNorm, th.nn.Embedding)
        for mn, m in self.named_modules():
            # recurse=False: each parameter is classified once, by the module that owns it.
            for pn, p in m.named_parameters(recurse=False):
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
                elif pn.endswith('_proj_weight') and isinstance(m, th.nn.MultiheadAttention):
                    # in/q/k/v projection matrices are Linear weights held on the MHA module itself.
                    decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        # Keep root-level positional parameters undecayed, but only if they actually exist.
        for name in ('pos_emb', 'global_pos_emb'):
            if name in param_dict:
                no_decay.add(name)

        # validate that we considered every parameter
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert len(
            param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params),)

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = th.optim.AdamW(optim_groups, lr=lr, betas=train_config.betas)
        return optimizer
